{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解读 FoldScaleAxis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```cpp\n",
    "/*!\n",
    " * \\brief Backward fold axis scaling into weights of conv/dense operators.\n",
    " *\n",
    " * \\return The pass.\n",
    " */\n",
    "TVM_DLL Pass BackwardFoldScaleAxis();\n",
    "\n",
    "/*!\n",
    " * \\brief Forward fold axis scaling into weights of conv/dense operators.\n",
    " *\n",
    " * \\return The pass.\n",
    " */\n",
    "TVM_DLL Pass ForwardFoldScaleAxis();\n",
    "\n",
    "/*!\n",
    " * \\brief A sequential pass that executes ForwardFoldScaleAxis and\n",
    " * BackwardFoldScaleAxis passes.\n",
    " *\n",
    " * \\return The pass.\n",
    " */\n",
    "TVM_DLL Pass FoldScaleAxis();\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这三个函数是用于在卷积/dense运算中折叠轴缩放的传递。它们分别是：\n",
    "\n",
    "1. `BackwardFoldScaleAxis()`：后向折叠轴缩放。这个函数的目的是在反向传播过程中，将轴缩放运算折叠到卷积/dense运算的权重中。这样可以减少计算量，提高性能。\n",
    "\n",
    "2. `ForwardFoldScaleAxis()`：前向折叠轴缩放。这个函数的目的是在前向传播过程中，将轴缩放运算折叠到卷积/dense运算的权重中。这样可以在不改变模型输出的情况下，减少计算量，提高性能。\n",
    "\n",
    "3. `FoldScaleAxis()`：这是一个顺序传递，它会依次执行 `ForwardFoldScaleAxis()` 和 `BackwardFoldScaleAxis()` 两个传递。这样可以同时优化前向和反向传播过程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relay.transform import FoldScaleAxis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m \u001b[0mFoldScaleAxis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mSource:\u001b[0m   \n",
      "\u001b[0;32mdef\u001b[0m \u001b[0mFoldScaleAxis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;34m\"\"\"Fold the scaling of axis into weights of conv2d/dense. This pass will\u001b[0m\n",
      "\u001b[0;34m    invoke both forward and backward scale folding.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Returns\u001b[0m\n",
      "\u001b[0;34m    -------\u001b[0m\n",
      "\u001b[0;34m    ret : tvm.transform.Pass\u001b[0m\n",
      "\u001b[0;34m        The registered pass to fold expressions.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Note\u001b[0m\n",
      "\u001b[0;34m    ----\u001b[0m\n",
      "\u001b[0;34m    Internally, we will call backward_fold_scale_axis before using\u001b[0m\n",
      "\u001b[0;34m    forward_fold_scale_axis as backward folding targets the common conv->bn\u001b[0m\n",
      "\u001b[0;34m    pattern.\u001b[0m\n",
      "\u001b[0;34m    \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;32mreturn\u001b[0m \u001b[0m_ffi_api\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFoldScaleAxis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/ai/tvm/python/tvm/relay/transform/transform.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "FoldScaleAxis??"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "函数的内部注释说明了其工作原理：在内部，先调用 `backward_fold_scale_axis`，然后再使用 `forward_fold_scale_axis`。这是因为后向折叠针对的是常见的卷积->批量标准化（Conv->BN）模式。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
